{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Empty the CUDA device list up front, before the model libraries are imported below.\n",
    "# NOTE(review): presumably this is meant to force CPU-only inference - confirm.\n",
    "os.environ.update(CUDA_VISIBLE_DEVICES='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/.local/lib/python3.8/site-packages/malaya/tokenizer.py:202: FutureWarning: Possible nested set at position 3361\n",
      "  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n",
      "/home/ubuntu/.local/lib/python3.8/site-packages/malaya/tokenizer.py:202: FutureWarning: Possible nested set at position 3879\n",
      "  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n"
     ]
    }
   ],
   "source": [
    "import malaya\n",
    "\n",
    "# Expose the tokenizer's bound `tokenize` method; later cells call it as\n",
    "# `' '.join(tokenize(text))` to obtain space-separated tokens.\n",
    "word_tokenizer = malaya.tokenizer.Tokenizer()\n",
    "tokenize = word_tokenizer.tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "# The tokenizer is loaded from this identifier rather than from the fine-tuned\n",
    "# checkpoint directory that the model weights come from.\n",
    "BASE_MODEL_ID = 'mesolitica/t5-small-standard-bahasa-cased'\n",
    "tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_ID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['finetune-t5-small-standard-bahasa-cased-ner-new/checkpoint-510000',\n",
       " 'finetune-t5-small-standard-bahasa-cased-ner-new/checkpoint-520000',\n",
       " 'finetune-t5-small-standard-bahasa-cased-ner-new/checkpoint-530000']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "\n",
    "def checkpoint_sort_key(path):\n",
    "    \"\"\"Sort key that orders checkpoint directories by their numeric step.\n",
    "\n",
    "    Returns `(step, path)`, where `step` is the integer after the last '-' in\n",
    "    `path`, or -1 when that suffix is not a number. `path` is included as a\n",
    "    tie-breaker so the ordering is deterministic.\n",
    "    \"\"\"\n",
    "    suffix = path.rstrip('/').rsplit('-', 1)[-1]\n",
    "    step = int(suffix) if suffix.isdecimal() else -1\n",
    "    return step, path\n",
    "\n",
    "\n",
    "# Sort by step number, not by string: lexicographically 'checkpoint-90000' sorts\n",
    "# after 'checkpoint-530000', which would make `checkpoints[-1]` pick an older checkpoint.\n",
    "checkpoints = sorted(\n",
    "    glob('finetune-t5-small-standard-bahasa-cased-ner-new/checkpoint-*'),\n",
    "    key=checkpoint_sort_key,\n",
    ")\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `checkpoints` was sorted in the previous cell; its last entry is the one loaded here.\n",
    "latest_checkpoint = checkpoints[-1]\n",
    "model = T5ForConditionalGeneration.from_pretrained(latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ekstrak nama seseorang dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak organisasi dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak kopi dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak parti politik dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak masa dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak food and beverages dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak age dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak politician dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak ahli politik dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak orang politik dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "Ekstrak parti politik malaysia dari teks: Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 ( PRU15 ) . Khalid , 79 , yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia ( Putra ) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional ( BN ) , calon Perikatan Nasional ( PN ) Noraffendy Salleh dan calon Pakatan Harapan ( PH ) Norwani Ahmat .\n",
      "nama seseorang ['Khalid Yunus', 'Khalid', 'Shamshulkahar Mohd Deli', 'Noraffendy Salleh', 'Norwani Ahmat']\n",
      "organisasi ['Barisan Nasional ( BN )', 'Perikatan Nasional ( PN )', 'Pakatan Harapan ( PH )']\n",
      "kopi ['tiada']\n",
      "parti politik ['tiada']\n",
      "masa ['tiada']\n",
      "food and beverages ['tiada']\n",
      "age ['tiada']\n",
      "politician ['Khalid Yunus', 'Khalid', 'Shamshulkahar Mohd Deli', 'Noraffendy Salleh', 'Norwani Ahmat']\n",
      "ahli politik ['Khalid Yunus', 'Khalid', 'Shamshulkahar Mohd Deli', 'Noraffendy Salleh', 'Norwani Ahmat']\n",
      "orang politik ['Khalid Yunus', 'Khalid', 'Shamshulkahar Mohd Deli', 'Noraffendy Salleh', 'Norwani Ahmat']\n",
      "parti politik malaysia ['tiada']\n"
     ]
    }
   ],
   "source": [
    "from unidecode import unidecode\n",
    "import re\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"Flatten `string` onto a single line with minimal cleaning.\n",
    "\n",
    "    Newlines become spaces, runs of spaces collapse to one space, and\n",
    "    leading/trailing whitespace is stripped.\n",
    "    \"\"\"\n",
    "    single_line = ' '.join(string.split('\\n'))\n",
    "    return re.sub(r'[ ]+', ' ', single_line).strip()\n",
    "\n",
    "raw_text = \"\"\"\n",
    "Bekas Ahli Parlimen Jempol empat penggal Khalid Yunus akan membuat penampilan semula pada pilihan raya umum ke-15 (PRU15).\n",
    "\n",
    "Khalid, 79, yang juga timbalan presiden Parti Bumiputera Perkasa Malaysia (Putra) akan bertanding atas tiket Pejuang menentang Shamshulkahar Mohd Deli daripada Barisan Nasional (BN), calon Perikatan Nasional (PN) Noraffendy Salleh dan calon Pakatan Harapan (PH) Norwani Ahmat.\n",
    "\"\"\"\n",
    "\n",
    "# Flatten the passage, then re-join the malaya tokens with single spaces.\n",
    "text = ' '.join(tokenize(cleaning(raw_text)))\n",
    "tags = [\n",
    "    'nama seseorang', 'organisasi', 'kopi', 'parti politik', 'masa', 'food and beverages',\n",
    "    'age', 'politician', 'ahli politik', 'orang politik', 'parti politik malaysia',\n",
    "]\n",
    "\n",
    "# One prompt per entity type, all asked against the same passage.\n",
    "prompts = [f'Ekstrak {tag} dari teks: {text}' for tag in tags]\n",
    "for prompt in prompts:\n",
    "    print(prompt)\n",
    "input_ids = [\n",
    "    {'input_ids': tokenizer.encode(prompt, return_tensors='pt')[0]}\n",
    "    for prompt in prompts\n",
    "]\n",
    "\n",
    "# Pad to the longest prompt so every tag is generated in a single batch.\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model.generate(**padded, max_length=256)\n",
    "decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "\n",
    "# NOTE(review): splitting on ' dan ' assumes the model joins multiple entities\n",
    "# with that word - confirm against the training targets.\n",
    "for tag, answer in zip(tags, decoded):\n",
    "    print(tag, answer.split(' dan '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ekstrak person dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak organisasi dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak kopi dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak nama parti politik dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak lokasi dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak syarikat dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak masa dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak quantity dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak kuantiti dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak coffee dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak airport dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "Ekstrak politician dari teks: olly saya nak secawan kopi dengan satu krim dan tiga gula , dan saya sekarang berada di penang airport\n",
      "person tiada\n",
      "organisasi tiada\n",
      "kopi dengan satu krim dan tiga gula\n",
      "nama parti politik tiada\n",
      "lokasi penang airport\n",
      "syarikat tiada\n",
      "masa tiada\n",
      "quantity tiada\n",
      "kuantiti tiada\n",
      "coffee dengan satu krim dan tiga gula\n",
      "airport tiada\n",
      "politician tiada\n"
     ]
    }
   ],
   "source": [
    "raw_text = 'olly saya nak secawan kopi dengan satu krim dan tiga gula, dan saya sekarang berada di penang airport'\n",
    "# Re-join the malaya tokens with single spaces before building the prompts.\n",
    "text = ' '.join(tokenize(raw_text))\n",
    "tags = [\n",
    "    'person', 'organisasi', 'kopi', 'nama parti politik', 'lokasi', 'syarikat', 'masa',\n",
    "    'quantity', 'kuantiti', 'coffee', 'airport', 'politician',\n",
    "]\n",
    "\n",
    "# One prompt per entity type, all asked against the same sentence.\n",
    "prompts = [f'Ekstrak {tag} dari teks: {text}' for tag in tags]\n",
    "for prompt in prompts:\n",
    "    print(prompt)\n",
    "input_ids = [\n",
    "    {'input_ids': tokenizer.encode(prompt, return_tensors='pt')[0]}\n",
    "    for prompt in prompts\n",
    "]\n",
    "\n",
    "# Pad to the longest prompt so every tag is generated in a single batch.\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model.generate(**padded, max_length=256)\n",
    "decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "for tag, answer in zip(tags, decoded):\n",
    "    print(tag, answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ekstrak person dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak organisasi dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak kopi dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak nama parti politik dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak lokasi dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak syarikat dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak masa dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak quantity dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak kuantiti dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak coffee dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak airport dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak drink dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak fnb dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "Ekstrak makanan dan minuman dari teks: sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks\n",
      "person starbucks\n",
      "organisasi tiada\n",
      "kopi tiada\n",
      "nama parti politik tiada\n",
      "lokasi penang airport\n",
      "syarikat starbucks\n",
      "masa tiada\n",
      "quantity tiada\n",
      "kuantiti tiada\n",
      "coffee tiada\n",
      "airport tiada\n",
      "drink tiada\n",
      "fnb 1 teh o ais\n",
      "makanan dan minuman 1 teh o ais\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): unlike the two preceding cells, this text is not passed through\n",
    "# `tokenize`, so the comma stays attached to 'ais' - confirm this is intentional.\n",
    "text = 'sya nak 1 teh o ais, dan saya sekarang berada di penang airport sambil minum starbucks'\n",
    "tags = [\n",
    "    'person', 'organisasi', 'kopi', 'nama parti politik', 'lokasi', 'syarikat', 'masa',\n",
    "    'quantity', 'kuantiti', 'coffee', 'airport', 'drink', 'fnb', 'makanan dan minuman',\n",
    "]\n",
    "\n",
    "# One prompt per entity type, all asked against the same sentence.\n",
    "prompts = [f'Ekstrak {tag} dari teks: {text}' for tag in tags]\n",
    "for prompt in prompts:\n",
    "    print(prompt)\n",
    "input_ids = [\n",
    "    {'input_ids': tokenizer.encode(prompt, return_tensors='pt')[0]}\n",
    "    for prompt in prompts\n",
    "]\n",
    "\n",
    "# Pad to the longest prompt so every tag is generated in a single batch.\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model.generate(**padded, max_length=256)\n",
    "decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "for tag, answer in zip(tags, decoded):\n",
    "    print(tag, answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ekstrak year dari teks: The Square is located by the historic Harvard Yard where you find buildings dating back as far as the early 18th century .\n",
      "Ekstrak person dari teks: The Square is located by the historic Harvard Yard where you find buildings dating back as far as the early 18th century .\n",
      "Ekstrak orang dari teks: The Square is located by the historic Harvard Yard where you find buildings dating back as far as the early 18th century .\n",
      "Ekstrak org dari teks: The Square is located by the historic Harvard Yard where you find buildings dating back as far as the early 18th century .\n",
      "Ekstrak people dari teks: The Square is located by the historic Harvard Yard where you find buildings dating back as far as the early 18th century .\n",
      "Ekstrak tahun dari teks: The Square is located by the historic Harvard Yard where you find buildings dating back as far as the early 18th century .\n",
      "year tiada\n",
      "person tiada\n",
      "orang tiada\n",
      "org Harvard Yard\n",
      "people tiada\n",
      "tahun tiada\n"
     ]
    }
   ],
   "source": [
    "# English sentence used as-is (no `tokenize` call in this cell).\n",
    "text = 'The Square is located by the historic Harvard Yard where you find buildings dating back as far as the early 18th century .'\n",
    "tags = ['year', 'person', 'orang', 'org', 'people', 'tahun']\n",
    "\n",
    "# One prompt per entity type, all asked against the same sentence.\n",
    "prompts = [f'Ekstrak {tag} dari teks: {text}' for tag in tags]\n",
    "for prompt in prompts:\n",
    "    print(prompt)\n",
    "input_ids = [\n",
    "    {'input_ids': tokenizer.encode(prompt, return_tensors='pt')[0]}\n",
    "    for prompt in prompts\n",
    "]\n",
    "\n",
    "# Pad to the longest prompt so every tag is generated in a single batch.\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model.generate(**padded, max_length=256)\n",
    "decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "for tag, answer in zip(tags, decoded):\n",
    "    print(tag, answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Cloning https://huggingface.co/mesolitica/finetune-zeroshot-ner-t5-small-standard-bahasa-cased into local empty directory.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e23d094c15da4a488540ce1f3064bfba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Download file pytorch_model.bin:   0%|          | 8.74k/231M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "66c50440ecad4bdba81d7d0ea1bcaf03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Download file spiece.model:   0%|          | 1.43k/784k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c46ae0020dcf49e1ae48ecc44c8c809d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Clean file spiece.model:   0%|          | 1.00k/784k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a244aa42a6b477f8a2c76af39d023bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Clean file pytorch_model.bin:   0%|          | 1.00k/231M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83fc1cb517e74fc2a397b3b400559465",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file pytorch_model.bin:   0%|          | 4.00k/231M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-zeroshot-ner-t5-small-standard-bahasa-cased\n",
      "   7fba56e..a4ee1b8  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-zeroshot-ner-t5-small-standard-bahasa-cased/commit/a4ee1b893e38bcfec52fa940d41c83acb16daa43'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the loaded model to the Hub repository below, under the 'mesolitica' organization.\n",
    "model.push_to_hub(\n",
    "    'finetune-zeroshot-ner-t5-small-standard-bahasa-cased',\n",
    "    organization='mesolitica',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the tokenizer to the same repository name and organization as the model.\n",
    "tokenizer.push_to_hub(\n",
    "    'finetune-zeroshot-ner-t5-small-standard-bahasa-cased',\n",
    "    organization='mesolitica',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import collections\n",
    "import string\n",
    "\n",
    "# Built once here instead of on every call to `normalize_answer`.\n",
    "_ARTICLES_RE = re.compile(r'\\b(a|an|the)\\b', re.UNICODE)\n",
    "_STRIP_PUNCTUATION = str.maketrans('', '', string.punctuation)\n",
    "\n",
    "\n",
    "def normalize_answer(s):\n",
    "    \"\"\"Canonicalise an answer string before comparison.\n",
    "\n",
    "    Steps, in order: lowercase, delete every `string.punctuation` character,\n",
    "    replace the English articles a/an/the with a space, collapse whitespace.\n",
    "    \"\"\"\n",
    "    lowered = s.lower()\n",
    "    without_punc = lowered.translate(_STRIP_PUNCTUATION)\n",
    "    without_articles = _ARTICLES_RE.sub(' ', without_punc)\n",
    "    return ' '.join(without_articles.split())\n",
    "\n",
    "\n",
    "def get_tokens(s):\n",
    "    \"\"\"Return the whitespace-separated tokens of the normalised `s` ([] for a falsy `s`).\"\"\"\n",
    "    return normalize_answer(s).split() if s else []\n",
    "\n",
    "\n",
    "def compute_exact(a_gold, a_pred):\n",
    "    \"\"\"Return 1 when both answers are equal after normalisation, otherwise 0.\"\"\"\n",
    "    return 1 if normalize_answer(a_gold) == normalize_answer(a_pred) else 0\n",
    "\n",
    "\n",
    "def compute_f1(a_gold, a_pred):\n",
    "    \"\"\"Token-level F1 between the gold and predicted answers.\n",
    "\n",
    "    Tokens are compared as multisets. Returns an int (0 or 1) when either side\n",
    "    has no tokens or nothing overlaps, and a float otherwise.\n",
    "    \"\"\"\n",
    "    gold_toks = get_tokens(a_gold)\n",
    "    pred_toks = get_tokens(a_pred)\n",
    "    if not gold_toks or not pred_toks:\n",
    "        # With an empty side, the score is 1 only if both sides are empty.\n",
    "        return int(gold_toks == pred_toks)\n",
    "    overlap = collections.Counter(gold_toks) & collections.Counter(pred_toks)\n",
    "    num_same = sum(overlap.values())\n",
    "    if num_same == 0:\n",
    "        return 0\n",
    "    precision = num_same / len(pred_toks)\n",
    "    recall = num_same / len(gold_toks)\n",
    "    return (2 * precision * recall) / (precision + recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "170139it [1:55:11, 24.62it/s]\n"
     ]
    }
   ],
   "source": [
    "f1, exact = [], []\n",
    "\n",
    "# Read the test set one JSON record per line; explicit UTF-8 so decoding does not\n",
    "# depend on the platform's default locale.\n",
    "with open('shuffled-test.json', encoding='utf-8') as fopen:\n",
    "    for line in tqdm(fopen):\n",
    "        if not line.strip():\n",
    "            # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.\n",
    "            continue\n",
    "        data = json.loads(line)['translation']\n",
    "        # Single-example batch: the record's prompt prefix followed by its source text.\n",
    "        encoded = tokenizer.encode(data['prefix'] + data['src'], return_tensors='pt')[0]\n",
    "        padded = tokenizer.pad([{'input_ids': encoded}], padding='longest')\n",
    "        outputs = model.generate(**padded, max_length=256)\n",
    "        prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n",
    "        target = data['tgt']\n",
    "        # The metric helpers are declared as (a_gold, a_pred), so pass the target first.\n",
    "        f1.append(compute_f1(target, prediction))\n",
    "        exact.append(compute_exact(target, prediction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9512811620321475, 0.9208940924773273)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Average the per-example scores collected in `f1` and `exact`; shown as (F1, exact match).\n",
    "mean_f1, mean_exact = np.mean(f1), np.mean(exact)\n",
    "mean_f1, mean_exact"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
